# Load libraries
library(text2vec)
## Warning: package 'text2vec' was built under R version 4.4.3
library(stringdist)
## Warning: package 'stringdist' was built under R version 4.4.2
library(digest)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
library(cowplot)
## Warning: package 'cowplot' was built under R version 4.4.2
library(tidyr)
## 
## Attaching package: 'tidyr'
## The following object is masked from 'package:stringdist':
## 
##     extract
library(stringr)
library(plotly)
## 
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## The following object is masked from 'package:stats':
## 
##     filter
## The following object is masked from 'package:graphics':
## 
##     layout
library(RColorBrewer)
# Load the input dataset (CSV expected in the working directory)
data <- read.csv(file = "df_final.csv")

# Part 1: Standardize employment organization names and clean data

# Workflow

# Distinct organization names found in the data
companies <- data$org_name %>% unique()

# STEP 1: Clean Company Names
#
# Normalises raw organization names so that spelling variants collapse to the
# same string: lower-cases, drops parenthetical remarks, strips punctuation,
# removes generic corporate filler words and squeezes whitespace.
#
# Args:
#   companies: character vector of raw organization names.
# Returns:
#   Character vector of the same length with the cleaned names. A name that
#   consists only of filler words becomes "". NA inputs come back as the
#   string "NA" (paste() stringifies them), as before.
clean_company_names <- function(companies) {
  filler_words <- c(
    "inc", "llc", "corp", "solutions", "advisors", "financial",
    "company", "international", "group", "co", "capital",
    "services", "management", "partners", "software", "tech"
  )

  companies_cleaned <- tolower(companies)
  # Parentheticals have to be removed BEFORE punctuation: once "(" and ")"
  # are stripped there is nothing left for this pattern to match.
  companies_cleaned <- gsub("\\([^)]*\\)", "", companies_cleaned)
  companies_cleaned <- gsub("[[:punct:]]", "", companies_cleaned)

  vapply(
    strsplit(companies_cleaned, "\\s+"),
    function(words) {
      # Drop filler words and the empty token strsplit() yields for leading
      # whitespace, so the first remaining word is always a real word.
      keep <- nzchar(words) & !(words %in% filler_words)
      paste(words[keep], collapse = " ")
    },
    character(1)
  )
}

# STEP 2: Tokenize
#
# Splits every string on runs of whitespace.
#
# Args:
#   strings: vector of strings to tokenize.
# Returns:
#   A list with one character vector of tokens per input element
#   (an empty string yields character(0)).
word_tokenizer <- function(strings) {
  split_on_whitespace <- function(text) {
    pieces <- strsplit(text, "\\s+")
    pieces[[1]]
  }
  lapply(strings, split_on_whitespace)
}

cleaned_names <- clean_company_names(companies)
tokens <- word_tokenizer(cleaned_names)

# STEP 3: Group companies based on core brand
#
# Two companies belong to the same group when the first token of their
# cleaned names is equal. A company whose cleaned name has no tokens forms a
# group of its own. `groups` is a list of integer index vectors into
# `companies`, ordered by the position of each group's first member.
#
# Implemented with match()/split() instead of a nested scan over all pairs,
# which was quadratic in the number of companies and grew the list one
# element at a time.
has_token <- lengths(tokens) > 0
token_rows <- unname(which(has_token))
core_names <- vapply(tokens[token_rows], function(tok) tok[1], character(1))

# Group id = index of the first company sharing the same core name;
# token-less companies keep their own index and so stay singletons.
group_id <- seq_along(companies)
group_id[token_rows] <- token_rows[match(core_names, core_names)]

groups <- unname(split(seq_along(companies), group_id))

# STEP 4: Generate group labels and build lookup
#
# A group's label is its (up to) two most frequent tokens, joined by a space.
# sort() on the table is stable, so equally frequent tokens stay in
# alphabetical order. A group without any tokens gets the label "".
group_labels <- vapply(groups, function(members) {
  word_freq <- sort(table(unlist(tokens[members])), decreasing = TRUE)
  paste(names(head(word_freq, 2)), collapse = " ")
}, character(1))

# One row per unique raw company name: raw name, cleaned name, group label.
group_lookup <- data.frame(
  company = companies,
  cleaned = cleaned_names,
  group_label = rep(NA_character_, length(companies)),
  stringsAsFactors = FALSE
)
group_lookup$group_label[unlist(groups)] <- rep(group_labels, lengths(groups))

# STEP 5: Assign group labels
#
# Every row's org_name is one of `companies`, so the cleaned name and label
# are looked up by raw name instead of re-running the cleaner on every row.
data$cleaned_name <- group_lookup$cleaned[match(data$org_name, group_lookup$company)]
data$group_label <- group_lookup$group_label[match(data$org_name, group_lookup$company)]
# Step 6: Final standardization
#
# Ordered rules: regex pattern (matched case-insensitively against
# group_label) -> canonical firm name. When several patterns match, the one
# listed first wins. Labels matching no pattern keep their group_label.
# NOTE(review): most patterns are unanchored substrings (e.g. "ubs", "chase"),
# so they also hit labels that merely contain those letters.
standardization_rules <- c(
  "lynch" = "Merrill Lynch",
  "morgan stanley" = "Morgan Stanley",
  "citigroup" = "Citigroup",
  "wells fargo" = "Wells Fargo",
  "edward" = "Edward Jones",
  "ameriprise" = "Ameriprise",
  "strategic wealth" = "Strategic Wealth",
  "ubs" = "UBS",
  "lpl" = "LPL Financial",
  "equitable" = "Equitable",
  "osaic" = "Osaic",
  "a better" = "A Better Financial",
  "charles" = "Charles Schwab",
  "pruco" = "Pruco Securities",
  "cetera" = "Cetera",
  "banc america" = "Bank of America",
  "jp morgan" = "JPMorgan",
  "chase" = "JPMorgan Chase",
  "raymond james" = "Raymond James",
  "mml" = "MML Investors",
  "waddell" = "Waddell & Reed",
  "^msi$" = "MSI Financial",
  "fidelity" = "Fidelity",
  "securities america" = "Securities America",
  "cambridge" = "Cambridge Investment",
  "prudential" = "Prudential",
  "td ameritrade" = "TD Ameritrade",
  "rbc" = "RBC",
  "credit suisse" = "Credit Suisse",
  "vanguard" = "Vanguard",
  "principal" = "Principal",
  "fifth third" = "Fifth Third",
  "kestra" = "Kestra",
  "truist" = "Truist",
  "transamerica" = "Transamerica",
  "b riley" = "B. Riley",
  "us bancorp" = "US Bancorp",
  "usaa" = "USAA",
  "stifel" = "Stifel",
  "hsbc" = "HSBC",
  "ameritas" = "Ameritas",
  "fisher" = "Fisher Investments",
  "alliancebernstein" = "AllianceBernstein",
  "state farm" = "State Farm",
  "janney" = "Janney",
  "barclays" = "Barclays",
  "oppenheimer" = "Oppenheimer",
  "tiaacref" = "TIAA",
  "blackrock" = "BlackRock",
  "bmo" = "BMO",
  "t price" = "T. Rowe Price"
)

# Start from the fallback value and apply the rules from last to first, so
# that an earlier rule overwrites any later one that also matched.
standardized_name <- as.character(data$group_label)
for (rule_idx in rev(seq_along(standardization_rules))) {
  hits <- grepl(
    names(standardization_rules)[rule_idx],
    data$group_label,
    ignore.case = TRUE
  )
  standardized_name[hits] <- standardization_rules[[rule_idx]]
}
data$standardized_name <- standardized_name
# Step 7: Clean data

# State abbreviation -> region mapping from R's built-in `state` datasets
# (state.region is a factor, so employment_region is one too)
state_region_df <- data.frame(
  employment_state = state.abb,
  employment_region = state.region,
  stringsAsFactors = FALSE
)

data <- data %>%
  # Drop helper columns that are no longer needed
  select(-c(most_common_org_name, cleaned_name, group_label, reg_location_count)) %>%
  rename(employment_city_state = most_common_city_state) %>%
  mutate(
    # With a single registration and a zero average duration, fall back to
    # the organization-level average duration
    ave_registration_duration = if_else(
      registration_count == 1 & ave_registration_duration == 0,
      ave_org_duration,
      ave_registration_duration
    ),
    # Months -> years
    ave_registration_duration_years = ave_registration_duration / 12,
    # State = whatever follows the last comma of "City, ST"
    employment_state = str_trim(str_extract(employment_city_state, "[^,]+$"))
  ) %>%
  # Values that are not one of the 50 state abbreviations get an NA region
  left_join(state_region_df, by = "employment_state")

# Part 2: Data exploration

# Row counts per standardized organization, largest first, keeping the 10
# highest counts (top_n() keeps ties, so more than 10 rows are possible)
top_10_orgs <- data %>%
  count(standardized_name) %>%
  arrange(desc(n)) %>%
  top_n(n = 10, wt = n)

# Horizontal bar chart, bars ordered by frequency
ggplot(top_10_orgs, aes(x = reorder(standardized_name, n), y = n)) +
  geom_bar(stat = "identity", fill = "lightblue", color = "black") +
  coord_flip() +
  scale_y_continuous(limits = c(0, NA), expand = c(0, 1)) +
  ggtitle("Top 10 Most Common Organizations") +
  xlab(NULL) +
  ylab("Frequency") +
  theme_classic() +
  theme(axis.text.x = element_text(hjust = 1))

# Theme fragment that blanks every y-axis element (used by the boxplot)
mytheme <- theme(
  axis.title.y = element_blank(),
  axis.text.y = element_blank(),
  axis.ticks.y = element_blank(),
  axis.line.y = element_blank()
)

# Histogram of the average registration duration in years: 1-year bins,
# x-axis breaks every 5 years up to the observed maximum
hist_plot <- ggplot(data, aes(x = ave_registration_duration_years)) +
  geom_histogram(binwidth = 1, fill = "lightblue", color = "black", alpha = 0.7) +
  scale_x_continuous(
    breaks = seq(
      from = 0,
      to = max(data$ave_registration_duration_years, na.rm = TRUE),
      by = 5
    )
  ) +
  scale_y_continuous(expand = c(0, 0)) +  # bars sit directly on the axis
  ggtitle("Histogram of Average Registration Duration (Years)") +
  xlab("Average Registration Duration (Years)") +
  ylab(NULL) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.title = element_text(size = 11),
    axis.text = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 10)
  )

# Horizontal boxplot of the same variable with all axis decoration removed
box_plot <- ggplot(data, aes(y = ave_registration_duration_years)) +
  geom_boxplot(
    fill = "lightblue",
    outlier.shape = 21,
    outlier.fill = "lightblue",
    outlier.alpha = 0.5
  ) +
  coord_flip() +
  theme_classic() +
  mytheme +
  theme(
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.line.x = element_blank()
  )

# Stack histogram over boxplot with cowplot, aligned on left/right axes
plot_grid(
  hist_plot, box_plot,
  align = "v", axis = "lr",
  ncol = 1, rel_heights = c(0.9, 0.15)
)

# For every employment region, the 3 standardized names with the most rows.
# Rows lacking a region or a name are excluded. The result stays grouped by
# employment_region.
top_3_by_region <- data %>%
  filter(!is.na(employment_region), !is.na(standardized_name)) %>%
  count(employment_region, standardized_name, name = "name_count") %>%
  arrange(desc(name_count)) %>%
  group_by(employment_region) %>%
  slice_head(n = 3)


# One facet per region, one bar per organization (identified by fill colour)
ggplot(
  top_3_by_region,
  aes(x = standardized_name, y = name_count, fill = standardized_name)
) +
  geom_col(position = "dodge", color = "black") +
  facet_grid(~ employment_region) +
  scale_fill_brewer(palette = "Set1") +
  scale_y_continuous(expand = c(0, 0)) +
  ggtitle("Top 3 Most Common Organizations by Employment Region") +
  ylab(NULL) +
  labs(fill = "Organization") +
  theme_classic() +
  theme(
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank()
  )

# Total registration_count per organization (rows without region or name
# excluded), keeping the 10 organizations with the largest totals.
# `.groups = "drop"` returns an ungrouped result without a summarise() message.
top_10_orgs_by_reg <- data %>%
  filter(!is.na(employment_region), !is.na(standardized_name)) %>%
  group_by(standardized_name) %>%
  summarise(
    total_registration_count = sum(registration_count, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  arrange(desc(total_registration_count)) %>%
  slice_head(n = 10)

# Horizontal bars so the organization names stay readable
ggplot(top_10_orgs_by_reg, aes(x = standardized_name, y = total_registration_count)) +
  geom_col(position = "stack", fill = "lightblue", color = "black") +
  coord_flip() +
  scale_y_continuous(expand = c(0, 0)) +
  ggtitle("Top 10 Org by Registration") +
  xlab(NULL) +
  ylab("registrations") +
  theme_classic()

# Total registration_count per (region, organization), then the 3 largest
# totals within each region. `.groups = "drop_last"` keeps the result grouped
# by employment_region, which is what makes slice_head() act per region.
top_3_orgs_by_reg_region <- data %>%
  filter(!is.na(employment_region), !is.na(standardized_name)) %>%
  group_by(employment_region, standardized_name) %>%
  summarise(
    total_registration_count = sum(registration_count, na.rm = TRUE),
    .groups = "drop_last"
  ) %>%
  arrange(desc(total_registration_count)) %>%
  slice_head(n = 3)

# One facet per region; organizations are distinguished by fill colour, so
# the x-axis text, ticks and title are removed
ggplot(
  top_3_orgs_by_reg_region,
  aes(x = standardized_name, y = total_registration_count, fill = standardized_name)
) +
  geom_col(position = "dodge", color = "black") +
  facet_grid(~ employment_region) +
  scale_fill_brewer(palette = "Set1") +
  scale_y_continuous(expand = c(0, 0)) +
  ggtitle("Top 3 Orgs by Registrations across Employment Regions") +
  ylab(NULL) +
  labs(fill = "Organization") +
  theme_classic() +
  theme(
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank()
  )

# Jittered scatter: registration_count (y) against ave_registration_duration
# in months (x)
ggplot(data, aes(x = ave_registration_duration, y = registration_count)) +
  geom_jitter(alpha = 0.5, color = "lightblue") +
  scale_y_continuous(expand = c(0, 0)) +
  ggtitle("Registration Count vs Average Registration Duration") +
  xlab("Average Registration Duration (Months)") +
  ylab("Registration Count") +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text = element_text(size = 10)
  )

# Plotly

# 1. Top 5 firms by number of distinct individual_id values.
#    Despite its name, `top_10_firms` holds 5 names (slice_head(n = 5)).
top_10_firms <- data %>%
  filter(!is.na(standardized_name), standardized_name != "") %>%
  group_by(standardized_name) %>%
  summarise(total_individuals = n_distinct(individual_id)) %>%
  arrange(desc(total_individuals)) %>%
  slice_head(n = 5) %>%
  pull(standardized_name)

# 2a. Aggregate by firm and state.
#     na.rm = TRUE (as in the registration totals computed for the bar
#     charts) so one missing value does not turn a whole firm/state total
#     into NA.
top_firm_state_data <- data %>%
  filter(standardized_name %in% top_10_firms) %>%
  group_by(standardized_name, employment_state) %>%
  summarise(
    individual_count = n_distinct(individual_id),
    total_registration_count = sum(registration_count, na.rm = TRUE),
    ave_registration_duration = round(mean(ave_registration_duration, na.rm = TRUE)),
    total_reg_city_count = sum(reg_city_count, na.rm = TRUE),
    .groups = "drop"
  )

# 2b. Same aggregation by firm and region
top_firm_region_data <- data %>%
  filter(standardized_name %in% top_10_firms) %>%
  group_by(standardized_name, employment_region) %>%
  summarise(
    individual_count = n_distinct(individual_id),
    total_registration_count = sum(registration_count, na.rm = TRUE),
    ave_registration_duration = round(mean(ave_registration_duration, na.rm = TRUE)),
    total_reg_city_count = sum(reg_city_count, na.rm = TRUE),
    .groups = "drop"
  )
# Fix the RNG so the jitter offsets drawn below are reproducible
set.seed(42)

# 3. State centroids from the built-in `state.center` dataset
state_coords <- data.frame(
  employment_state = state.abb,
  lat = state.center[["y"]],
  lon = state.center[["x"]],
  stringsAsFactors = FALSE
)

# 4. Attach centroids and jitter every marker by up to +/- 0.6 degrees so
#    firms in the same state do not overlap exactly. HI and AK are pinned to
#    fixed coordinates instead of their jittered centroid. Latitude offsets
#    are drawn before longitude offsets.
top_firm_state_data <- top_firm_state_data %>%
  left_join(state_coords, by = "employment_state") %>%
  mutate(
    lat_jitter = case_when(
      employment_state == "HI" ~ 20.5,
      employment_state == "AK" ~ 63,
      TRUE ~ lat + runif(n(), -0.6, 0.6)
    ),
    lon_jitter = case_when(
      employment_state == "HI" ~ -157.5,
      employment_state == "AK" ~ -150,
      TRUE ~ lon + runif(n(), -0.6, 0.6)
    )
  )


# 5. Plot
# One colour per firm (five firms -> five Set1 colours)
color_palette <- brewer.pal(n = 5, name = "Set1")

# Bubble map of the US: one marker per firm/state, coloured by firm and sized
# from the number of individuals; hovering shows the aggregated figures.
top_firm_state_data %>%
  plot_ly(
    type = "scattergeo",
    mode = "markers",
    locationmode = "USA-states",
    lon = ~lon_jitter,
    lat = ~lat_jitter,
    color = ~standardized_name,
    colors = color_palette,
    marker = list(
      sizemode = "area",
      size = ~sqrt(individual_count),
      opacity = 0.7,
      line = list(color = "black", width = 0.5)
    ),
    hoverinfo = "text",
    text = ~paste0(
      "Firm: ", standardized_name, "<br>",
      "State: ", employment_state, "<br>",
      "Individuals: ", individual_count, "<br>",
      "Registrations: ", total_registration_count, "<br>",
      "Registration average duration in months: ", ave_registration_duration, "<br>",
      "Registered in num Cities: ", total_reg_city_count
    )
  ) %>%
  layout(
    title = "Top 5 Firms by IAR Employment",
    geo = list(
      scope = "usa",
      showland = TRUE,
      landcolor = "rgb(240, 240, 240)",
      subunitcolor = "white"
    )
  )